import torch
import matplotlib.pyplot as plt

def Compute(x_s, func):
    """Evaluate ``func`` on ``x_s`` and return the result.

    Args:
        x_s: Input forwarded unchanged to ``func`` (in this file, the grid
            tensor passed by ``Visualization``).
        func: Callable applied once to the whole of ``x_s``.

    Returns:
        Whatever ``func(x_s)`` returns.
    """
    return func(x_s)

def funcplus(Functions):
    """Return a callable that sums the outputs of every function in ``Functions``.

    Args:
        Functions: Iterable of one-argument callables. It is captured by
            reference and iterated each time the returned callable runs.

    Returns:
        A callable ``newfunc(x)`` computing ``0 + f1(x) + f2(x) + ...``;
        it returns the int ``0`` when ``Functions`` is empty.
    """
    def newfunc(x):
        # Out-of-place addition (what ``sum`` does) lets dtypes promote, e.g.
        # an integer tensor plus a float tensor. The previous in-place ``+=``
        # raised in that case because the float result could not be cast back
        # into the integer accumulator.
        return sum(function(x) for function in Functions)
    return newfunc

def funcaverage(Functions):
    """Return a callable that averages the outputs of every function in ``Functions``.

    Args:
        Functions: Sized collection of one-argument callables (``len`` is
            taken). It is captured by reference and read each time the
            returned callable runs.

    Returns:
        A callable ``newfunc(x)`` computing ``(f1(x) + ... + fn(x)) / n``.

    Raises:
        ZeroDivisionError: when the returned callable is invoked and
            ``Functions`` is empty.
    """
    def newfunc(x):
        total = sum(function(x) for function in Functions)
        # Out-of-place true division: integer-valued outputs promote to float.
        # An in-place ``/=`` on an integer tensor/array raises instead.
        return total / len(Functions)
    return newfunc

def Visualization(x_s, Functions, name):
    """Draw every function in ``Functions`` as a 3D surface and save the figure.

    Args:
        x_s: Grid input. ``x_s[0]`` and ``x_s[1]`` are used as the X and Y
            coordinates, and the whole of ``x_s`` is passed to each function.
            Presumably a stacked meshgrid tensor of shape (2, H, W) -- TODO
            confirm with callers. Each slice must support
            ``.detach().numpy()``.
        Functions: Iterable of callables. Each is evaluated via
            ``Compute(x_s, f)`` and must return a tensor whose
            ``.detach().numpy()`` matches the X/Y grid shape.
        name: Destination passed to ``Figure.savefig``.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        # The X/Y grids do not depend on the function being drawn.
        x_grid = x_s[0].detach().numpy()
        y_grid = x_s[1].detach().numpy()
        for index, function in enumerate(Functions):
            z_s = Compute(x_s, function)
            surf = ax.plot_surface(x_grid, y_grid, z_s.detach().numpy(),
                                   cmap='viridis', label=index + 1)
            # NOTE(review): copying the private 3D colour attributes onto their
            # 2D counterparts is a common workaround so ``ax.legend()`` can
            # build handles for surfaces. No legend is drawn here, and these
            # are private matplotlib attributes that may change between versions.
            surf._facecolors2d = surf._facecolor3d
            surf._edgecolors2d = surf._edgecolor3d
        fig.savefig(name)
    finally:
        # Always release the figure, even if a function or savefig raises;
        # otherwise pyplot keeps it alive across repeated calls.
        plt.close(fig)